{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e6325c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ec0754c",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Movie Synopsis Feature Extraction with Bart text summarization\n",
    "\n",
    "In this notebook, will will make use of the BART [model](https://huggingface.co/transformers/model_doc/bart.html) to extract features from movie synopsis. \n",
    "\n",
    "Note: this notebook should be executed from within the below container:\n",
    "\n",
    "```\n",
    "docker pull huggingface/transformers-pytorch-gpu\n",
    "docker run --gpus=all  --rm -it --net=host -v $PWD:/workspace --ipc=host huggingface/transformers-pytorch-gpu \n",
    "```\n",
    "\n",
    "Then from within the container:\n",
    "```\n",
    "cd /workspace\n",
    "pip install jupyter jupyterlab\n",
    "jupyter server extension disable nbclassic\n",
    "jupyter-lab --allow-root --ip='0.0.0.0' --NotebookApp.token='admin'\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a58d7b6",
   "metadata": {},
   "source": [
    "First, we install some extra package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fdbbb3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install imdbpy\n",
    "\n",
    "# Cuda 11 and A100 support\n",
    "!pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "96af710c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'status': 'ok', 'restart': True}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import IPython\n",
    "\n",
    "IPython.Application.instance().kernel.do_shutdown(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da44429e",
   "metadata": {},
   "source": [
    "## Download pretrained BART model\n",
    "\n",
    "First, we download a pretrained BART model from HuggingFace library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "585ac7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BartTokenizer, BartModel\n",
    "import torch\n",
    "\n",
    "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n",
    "model = BartModel.from_pretrained('facebook/bart-large').cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42c13ac",
   "metadata": {},
   "source": [
    "## Extracting embeddings for all movie's synopsis\n",
    "\n",
    "We will use the average hidden state of the last decoder layer as text feature, comprising 1024 float values.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c3393aac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open('movies_info.pkl', 'rb') as f:\n",
    "    movies_infos = pickle.load(f)['movies_infos']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d3cfd702",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 62423/62423 [43:41<00:00, 23.81it/s]  \n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "embeddings = {}\n",
    "for movie, movie_info in tqdm(movies_infos.items()):\n",
    "    synopsis = None\n",
    "    synopsis = movie_info.get('synopsis')\n",
    "    if synopsis is None:\n",
    "        plots = movie_info.get('plot')\n",
    "        if plots is not None:\n",
    "            synopsis = plots[0]\n",
    "    \n",
    "    if synopsis is not None:\n",
    "        inputs = tokenizer(synopsis, return_tensors=\"pt\", truncation=True, max_length=1024).to('cuda')\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs, output_hidden_states=True)\n",
    "        embeddings[movie] = outputs.last_hidden_state.cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "90323589",
   "metadata": {},
   "outputs": [],
   "source": [
    "average_embeddings = {}\n",
    "for movie in embeddings:\n",
    "    average_embeddings[movie] = np.mean(embeddings[movie].squeeze(), axis=0)"
   ]
  },
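  {
   "cell_type": "markdown",
   "id": "b7e41a90",
   "metadata": {},
   "source": [
    "The averaging step above can be illustrated on a small stand-in array. This is a minimal sketch and not part of the original pipeline: `fake_hidden_state` is a hypothetical random array laid out like `outputs.last_hidden_state` (a batch of 1, a few tokens, 1024 features), and `mean_pool` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mean_pool(hidden_state):\n",
    "    # Drop only the batch axis, then average across the token axis.\n",
    "    return hidden_state.squeeze(axis=0).mean(axis=0)\n",
    "\n",
    "# Hypothetical stand-in for one movie: 7 tokens, 1024 features each.\n",
    "rng = np.random.default_rng(0)\n",
    "fake_hidden_state = rng.standard_normal((1, 7, 1024)).astype(np.float32)\n",
    "\n",
    "pooled = mean_pool(fake_hidden_state)\n",
    "print(pooled.shape)  # (1024,)\n",
    "```\n",
    "\n",
    "Each synopsis therefore maps to one fixed-length vector, regardless of how many tokens it contains."
   ]
  },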
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "30e0125b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('movies_synopsis_embeddings-1024.pkl', 'wb') as f:\n",
    "    pickle.dump({\"embeddings\": average_embeddings}, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
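  {
   "cell_type": "markdown",
   "id": "c4d92e5f",
   "metadata": {},
   "source": [
    "A saved file of this kind can be read back with `pickle`. The sketch below is self-contained and does not touch the real output: it writes a tiny hypothetical dictionary with the same `embeddings` key to a temporary file. The movie key `demo_movie` and the file name `demo_embeddings.pkl` are made up for illustration.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical miniature of the saved structure: one movie, one 1024-value vector.\n",
    "demo = {\"embeddings\": {\"demo_movie\": np.zeros(1024, dtype=np.float32)}}\n",
    "\n",
    "path = os.path.join(tempfile.mkdtemp(), \"demo_embeddings.pkl\")\n",
    "with open(path, \"wb\") as f:\n",
    "    pickle.dump(demo, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "# Read the file back and confirm the vector length.\n",
    "with open(path, \"rb\") as f:\n",
    "    loaded = pickle.load(f)[\"embeddings\"]\n",
    "\n",
    "print(loaded[\"demo_movie\"].shape)  # (1024,)\n",
    "```\n",
    "\n",
    "Only unpickle files from a source you trust, since `pickle.load` can execute arbitrary code."
   ]
  },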
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ab5a64",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
